{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp models.autoformer"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Autoformer"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Autoformer model tackles the challenge of finding reliable dependencies in the intricate temporal patterns of long-horizon forecasting.\n",
    "\n",
    "The architecture has the following distinctive features:\n",
    "- In-built progressive decomposition in trend and seasonal components based on a moving average filter.\n",
    "- Auto-Correlation mechanism that discovers the period-based dependencies by\n",
    "calculating the autocorrelation and aggregating similar sub-series based on the periodicity.\n",
    "- Classic encoder-decoder proposed by Vaswani et al. (2017) with a multi-head attention mechanism.\n",
    "\n",
    "The Autoformer model utilizes a two-component approach to define its embedding:\n",
    "- It employs encoded autoregressive features obtained from a convolution network.\n",
    "- Absolute positional embeddings obtained from calendar features are utilized."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**References**<br>\n",
    "- [Wu, Haixu, Jiehui Xu, Jianmin Wang, and Mingsheng Long. \"Autoformer: Decomposition transformers with auto-correlation for long-term series forecasting\"](https://proceedings.neurips.cc/paper/2021/hash/bcc0d400288793e8bdcd7c19a8ac0c2b-Abstract.html)<br>"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Figure 1. Autoformer Architecture.](imgs_models/autoformer.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import math\n",
    "import numpy as np\n",
    "from typing import Optional\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from neuralforecast.common._modules import DataEmbedding\n",
    "from neuralforecast.common._base_windows import BaseWindows\n",
    "\n",
    "from neuralforecast.losses.pytorch import MAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from fastcore.test import test_eq\n",
    "from nbdev.showdoc import show_doc"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Auxiliary Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class AutoCorrelation(nn.Module):\n",
    "    \"\"\"\n",
    "    AutoCorrelation Mechanism with the following two phases:\n",
    "    (1) period-based dependencies discovery\n",
    "    (2) time delay aggregation\n",
    "    This block can replace the self-attention family mechanism seamlessly.\n",
    "    \"\"\"\n",
    "    def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False):\n",
    "        super(AutoCorrelation, self).__init__()\n",
    "        self.factor = factor\n",
    "        self.scale = scale\n",
    "        self.mask_flag = mask_flag\n",
    "        self.output_attention = output_attention\n",
    "        self.dropout = nn.Dropout(attention_dropout)\n",
    "\n",
    "    def time_delay_agg_training(self, values, corr):\n",
    "        \"\"\"\n",
    "        SpeedUp version of Autocorrelation (a batch-normalization style design)\n",
    "        This is for the training phase.\n",
    "        \"\"\"\n",
    "        head = values.shape[1]\n",
    "        channel = values.shape[2]\n",
    "        length = values.shape[3]\n",
    "        # find top k\n",
    "        top_k = int(self.factor * math.log(length))\n",
    "        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)\n",
    "        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]\n",
    "        weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)\n",
    "        # update corr\n",
    "        tmp_corr = torch.softmax(weights, dim=-1)\n",
    "        # aggregation\n",
    "        tmp_values = values\n",
    "        delays_agg = torch.zeros_like(values, dtype=torch.float, device=values.device)\n",
    "        for i in range(top_k):\n",
    "            pattern = torch.roll(tmp_values, -int(index[i]), -1)\n",
    "            delays_agg = delays_agg + pattern * \\\n",
    "                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))\n",
    "        return delays_agg\n",
    "\n",
    "    def time_delay_agg_inference(self, values, corr):\n",
    "        \"\"\"\n",
    "        SpeedUp version of Autocorrelation (a batch-normalization style design)\n",
    "        This is for the inference phase.\n",
    "        \"\"\"\n",
    "        batch = values.shape[0]\n",
    "        head = values.shape[1]\n",
    "        channel = values.shape[2]\n",
    "        length = values.shape[3]\n",
    "        # index init\n",
    "        init_index = torch.arange(length, device=values.device).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1)\n",
    "        # find top k\n",
    "        top_k = int(self.factor * math.log(length))\n",
    "        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)\n",
    "        weights = torch.topk(mean_value, top_k, dim=-1)[0]\n",
    "        delay = torch.topk(mean_value, top_k, dim=-1)[1]\n",
    "        # update corr\n",
    "        tmp_corr = torch.softmax(weights, dim=-1)\n",
    "        # aggregation\n",
    "        tmp_values = values.repeat(1, 1, 1, 2)\n",
    "        delays_agg = torch.zeros_like(values, dtype=torch.float, device=values.device)\n",
    "        for i in range(top_k):\n",
    "            tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)\n",
    "            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)\n",
    "            delays_agg = delays_agg + pattern * \\\n",
    "                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))\n",
    "        return delays_agg\n",
    "\n",
    "    def time_delay_agg_full(self, values, corr):\n",
    "        \"\"\"\n",
    "        Standard version of Autocorrelation\n",
    "        \"\"\"\n",
    "        batch = values.shape[0]\n",
    "        head = values.shape[1]\n",
    "        channel = values.shape[2]\n",
    "        length = values.shape[3]\n",
    "        # index init\n",
    "        init_index = torch.arange(length, device=values.device).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1)\n",
    "        # find top k\n",
    "        top_k = int(self.factor * math.log(length))\n",
    "        weights = torch.topk(corr, top_k, dim=-1)[0]\n",
    "        delay = torch.topk(corr, top_k, dim=-1)[1]\n",
    "        # update corr\n",
    "        tmp_corr = torch.softmax(weights, dim=-1)\n",
    "        # aggregation\n",
    "        tmp_values = values.repeat(1, 1, 1, 2)\n",
    "        delays_agg = torch.zeros_like(values, dtype=torch.float, device=values.device)\n",
    "        for i in range(top_k):\n",
    "            tmp_delay = init_index + delay[..., i].unsqueeze(-1)\n",
    "            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)\n",
    "            delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))\n",
    "        return delays_agg\n",
    "\n",
    "    def forward(self, queries, keys, values, attn_mask):\n",
    "        B, L, H, E = queries.shape\n",
    "        _, S, _, D = values.shape\n",
    "        if L > S:\n",
    "            zeros = torch.zeros_like(queries[:, :(L - S), :], dtype=torch.float, device=queries.device)\n",
    "            values = torch.cat([values, zeros], dim=1)\n",
    "            keys = torch.cat([keys, zeros], dim=1)\n",
    "        else:\n",
    "            values = values[:, :L, :, :]\n",
    "            keys = keys[:, :L, :, :]\n",
    "\n",
    "        # period-based dependencies\n",
    "        q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)\n",
    "        k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)\n",
    "        res = q_fft * torch.conj(k_fft)\n",
    "        corr = torch.fft.irfft(res, dim=-1)\n",
    "\n",
    "        # time delay agg\n",
    "        if self.training:\n",
    "            V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)\n",
    "        else:\n",
    "            V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)\n",
    "\n",
    "        if self.output_attention:\n",
    "            return (V.contiguous(), corr.permute(0, 3, 1, 2))\n",
    "        else:\n",
    "            return (V.contiguous(), None)\n",
    "\n",
    "\n",
    "class AutoCorrelationLayer(nn.Module):\n",
    "    def __init__(self, correlation, hidden_size, n_head, d_keys=None,\n",
    "                 d_values=None):\n",
    "        super(AutoCorrelationLayer, self).__init__()\n",
    "\n",
    "        d_keys = d_keys or (hidden_size // n_head)\n",
    "        d_values = d_values or (hidden_size // n_head)\n",
    "\n",
    "        self.inner_correlation = correlation\n",
    "        self.query_projection = nn.Linear(hidden_size, d_keys * n_head)\n",
    "        self.key_projection = nn.Linear(hidden_size, d_keys * n_head)\n",
    "        self.value_projection = nn.Linear(hidden_size, d_values * n_head)\n",
    "        self.out_projection = nn.Linear(d_values * n_head, hidden_size)\n",
    "        self.n_head = n_head\n",
    "\n",
    "    def forward(self, queries, keys, values, attn_mask):\n",
    "        B, L, _ = queries.shape\n",
    "        _, S, _ = keys.shape\n",
    "        H = self.n_head\n",
    "\n",
    "        queries = self.query_projection(queries).view(B, L, H, -1)\n",
    "        keys = self.key_projection(keys).view(B, S, H, -1)\n",
    "        values = self.value_projection(values).view(B, S, H, -1)\n",
    "\n",
    "        out, attn = self.inner_correlation(\n",
    "            queries,\n",
    "            keys,\n",
    "            values,\n",
    "            attn_mask\n",
    "        )\n",
    "        out = out.view(B, L, -1)\n",
    "\n",
    "        return self.out_projection(out), attn\n",
    "    \n",
    "\n",
    "class LayerNorm(nn.Module):\n",
    "    \"\"\"\n",
    "    Special designed layernorm for the seasonal part\n",
    "    \"\"\"\n",
    "    def __init__(self, channels):\n",
    "        super(LayerNorm, self).__init__()\n",
    "        self.layernorm = nn.LayerNorm(channels)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x_hat = self.layernorm(x)\n",
    "        bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1)\n",
    "        return x_hat - bias\n",
    "\n",
    "\n",
    "class MovingAvg(nn.Module):\n",
    "    \"\"\"\n",
    "    Moving average block to highlight the trend of time series\n",
    "    \"\"\"\n",
    "    def __init__(self, kernel_size, stride):\n",
    "        super(MovingAvg, self).__init__()\n",
    "        self.kernel_size = kernel_size\n",
    "        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # padding on the both ends of time series\n",
    "        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)\n",
    "        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)\n",
    "        x = torch.cat([front, x, end], dim=1)\n",
    "        x = self.avg(x.permute(0, 2, 1))\n",
    "        x = x.permute(0, 2, 1)\n",
    "        return x\n",
    "\n",
    "\n",
    "class SeriesDecomp(nn.Module):\n",
    "    \"\"\"\n",
    "    Series decomposition block.\n",
    "\n",
    "    Returns (seasonal, trend): the trend is the stride-1 moving average of x and the\n",
    "    seasonal part is the residual x - trend.\n",
    "    \"\"\"\n",
    "    def __init__(self, kernel_size):\n",
    "        super(SeriesDecomp, self).__init__()\n",
    "        self.MovingAvg = MovingAvg(kernel_size, stride=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        trend = self.MovingAvg(x)\n",
    "        return x - trend, trend\n",
    "\n",
    "\n",
    "class EncoderLayer(nn.Module):\n",
    "    \"\"\"\n",
    "    Autoformer encoder layer with the progressive decomposition architecture.\n",
    "\n",
    "    attention -> residual -> decomposition, then a position-wise feed-forward made of two\n",
    "    kernel-size-1 convolutions -> residual -> decomposition. The trends extracted by both\n",
    "    decompositions are discarded; only the seasonal part is returned, with the attention output.\n",
    "    \"\"\"\n",
    "    def __init__(self, attention, hidden_size, conv_hidden_size=None, MovingAvg=25, dropout=0.1, activation=\"relu\"):\n",
    "        super(EncoderLayer, self).__init__()\n",
    "        if not conv_hidden_size:\n",
    "            conv_hidden_size = 4 * hidden_size\n",
    "        self.attention = attention\n",
    "        self.conv1 = nn.Conv1d(in_channels=hidden_size, out_channels=conv_hidden_size, kernel_size=1, bias=False)\n",
    "        self.conv2 = nn.Conv1d(in_channels=conv_hidden_size, out_channels=hidden_size, kernel_size=1, bias=False)\n",
    "        # `MovingAvg` here is the integer window size, not the MovingAvg class\n",
    "        self.decomp1 = SeriesDecomp(MovingAvg)\n",
    "        self.decomp2 = SeriesDecomp(MovingAvg)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        # any value other than \"relu\" selects GELU\n",
    "        self.activation = F.relu if activation == \"relu\" else F.gelu\n",
    "\n",
    "    def forward(self, x, attn_mask=None):\n",
    "        attended, attn = self.attention(x, x, x, attn_mask=attn_mask)\n",
    "        seasonal, _ = self.decomp1(x + self.dropout(attended))\n",
    "        # Conv1d works on (batch, channels, time), hence the transposes around the FFN\n",
    "        hidden = self.dropout(self.activation(self.conv1(seasonal.transpose(-1, 1))))\n",
    "        hidden = self.dropout(self.conv2(hidden).transpose(-1, 1))\n",
    "        out, _ = self.decomp2(seasonal + hidden)\n",
    "        return out, attn\n",
    "\n",
    "\n",
    "class Encoder(nn.Module):\n",
    "    \"\"\"\n",
    "    Autoformer encoder\n",
    "    \"\"\"\n",
    "    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):\n",
    "        super(Encoder, self).__init__()\n",
    "        self.attn_layers = nn.ModuleList(attn_layers)\n",
    "        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None\n",
    "        self.norm = norm_layer\n",
    "\n",
    "    def forward(self, x, attn_mask=None):\n",
    "        attns = []\n",
    "        if self.conv_layers is not None:\n",
    "            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):\n",
    "                x, attn = attn_layer(x, attn_mask=attn_mask)\n",
    "                x = conv_layer(x)\n",
    "                attns.append(attn)\n",
    "            x, attn = self.attn_layers[-1](x)\n",
    "            attns.append(attn)\n",
    "        else:\n",
    "            for attn_layer in self.attn_layers:\n",
    "                x, attn = attn_layer(x, attn_mask=attn_mask)\n",
    "                attns.append(attn)\n",
    "\n",
    "        if self.norm is not None:\n",
    "            x = self.norm(x)\n",
    "\n",
    "        return x, attns\n",
    "\n",
    "\n",
    "class DecoderLayer(nn.Module):\n",
    "    \"\"\"\n",
    "    Autoformer decoder layer with the progressive decomposition architecture.\n",
    "\n",
    "    Self-attention, cross-attention over `cross` and a position-wise feed-forward are each\n",
    "    followed by residual + decomposition. The three extracted trends are summed and\n",
    "    projected to `c_out` channels by a circular-padded convolution.\n",
    "    Returns (seasonal x, residual_trend).\n",
    "    \"\"\"\n",
    "    def __init__(self, self_attention, cross_attention, hidden_size, c_out, conv_hidden_size=None,\n",
    "                 MovingAvg=25, dropout=0.1, activation=\"relu\"):\n",
    "        super(DecoderLayer, self).__init__()\n",
    "        if not conv_hidden_size:\n",
    "            conv_hidden_size = 4 * hidden_size\n",
    "        self.self_attention = self_attention\n",
    "        self.cross_attention = cross_attention\n",
    "        self.conv1 = nn.Conv1d(in_channels=hidden_size, out_channels=conv_hidden_size, kernel_size=1, bias=False)\n",
    "        self.conv2 = nn.Conv1d(in_channels=conv_hidden_size, out_channels=hidden_size, kernel_size=1, bias=False)\n",
    "        # `MovingAvg` here is the integer window size, not the MovingAvg class\n",
    "        self.decomp1 = SeriesDecomp(MovingAvg)\n",
    "        self.decomp2 = SeriesDecomp(MovingAvg)\n",
    "        self.decomp3 = SeriesDecomp(MovingAvg)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.projection = nn.Conv1d(in_channels=hidden_size, out_channels=c_out, kernel_size=3, stride=1, padding=1,\n",
    "                                    padding_mode='circular', bias=False)\n",
    "        # any value other than \"relu\" selects GELU\n",
    "        self.activation = F.relu if activation == \"relu\" else F.gelu\n",
    "\n",
    "    def forward(self, x, cross, x_mask=None, cross_mask=None):\n",
    "        self_out = self.self_attention(x, x, x, attn_mask=x_mask)[0]\n",
    "        x, trend_self = self.decomp1(x + self.dropout(self_out))\n",
    "        cross_out = self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0]\n",
    "        x, trend_cross = self.decomp2(x + self.dropout(cross_out))\n",
    "        # position-wise feed-forward; Conv1d works on (batch, channels, time)\n",
    "        ff = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))\n",
    "        ff = self.dropout(self.conv2(ff).transpose(-1, 1))\n",
    "        x, trend_ff = self.decomp3(x + ff)\n",
    "\n",
    "        total_trend = trend_self + trend_cross + trend_ff\n",
    "        residual_trend = self.projection(total_trend.permute(0, 2, 1)).transpose(1, 2)\n",
    "        return x, residual_trend\n",
    "\n",
    "\n",
    "class Decoder(nn.Module):\n",
    "    \"\"\"\n",
    "    Autoformer decoder\n",
    "    \"\"\"\n",
    "    def __init__(self, layers, norm_layer=None, projection=None):\n",
    "        super(Decoder, self).__init__()\n",
    "        self.layers = nn.ModuleList(layers)\n",
    "        self.norm = norm_layer\n",
    "        self.projection = projection\n",
    "\n",
    "    def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):\n",
    "        for layer in self.layers:\n",
    "            x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)\n",
    "            trend = trend + residual_trend\n",
    "\n",
    "        if self.norm is not None:\n",
    "            x = self.norm(x)\n",
    "\n",
    "        if self.projection is not None:\n",
    "            x = self.projection(x)\n",
    "        return x, trend"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Autoformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class Autoformer(BaseWindows):\n",
    "    \"\"\" Autoformer\n",
    "\n",
    "    The Autoformer model tackles the challenge of finding reliable dependencies in the intricate temporal patterns of long-horizon forecasting.\n",
    "\n",
    "    The architecture has the following distinctive features:\n",
    "    - In-built progressive decomposition in trend and seasonal components based on a moving average filter.\n",
    "    - Auto-Correlation mechanism that discovers the period-based dependencies by\n",
    "    calculating the autocorrelation and aggregating similar sub-series based on the periodicity.\n",
    "    - Classic encoder-decoder proposed by Vaswani et al. (2017) with a multi-head attention mechanism.\n",
    "\n",
    "    The Autoformer model utilizes a two-component approach to define its embedding:\n",
    "    - It employs encoded autoregressive features obtained from a convolution network.\n",
    "    - Absolute positional embeddings obtained from calendar features are utilized.\n",
    "\n",
    "    *Parameters:*<br>\n",
    "    `h`: int, forecast horizon.<br>\n",
    "    `input_size`: int, maximum sequence length for truncated train backpropagation.<br>\n",
    "    `futr_exog_list`: str list, future exogenous columns, fed to the encoder and decoder embeddings as temporal features.<br>\n",
    "    `hist_exog_list`: str list, historic exogenous columns (not supported yet; must be empty).<br>\n",
    "    `stat_exog_list`: str list, static exogenous columns (not supported yet; must be empty).<br>\n",
    "    `exclude_insample_y`: bool=False, the model skips the autoregressive features y[t-input_size:t] if True.<br>\n",
    "    `decoder_input_size_multiplier`: float=0.5, fraction of `input_size` (rounded up) of recent history given to the decoder as its start sequence; must yield a length in (0, input_size).<br>\n",
    "    `hidden_size`: int=128, units of embeddings and encoders.<br>\n",
    "    `dropout`: float=0.05, dropout throughout Autoformer architecture, in (0, 1).<br>\n",
    "    `factor`: int=3, Auto-Correlation factor; each correlation aggregates about int(factor * log(length)) time delays.<br>\n",
    "    `n_head`: int=4, controls number of multi-head's attention.<br>\n",
    "    `conv_hidden_size`: int=32, channels of the position-wise convolutional feed-forward in encoder and decoder layers.<br>\n",
    "    `activation`: str='gelu', activation from ['relu', 'gelu'].<br>\n",
    "    `encoder_layers`: int=2, number of layers for the Autoformer encoder.<br>\n",
    "    `decoder_layers`: int=1, number of layers for the Autoformer decoder.<br>\n",
    "    `MovingAvg_window`: int=25, window size of the moving average used by the series decompositions.<br>\n",
    "    `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
    "    `valid_loss`: PyTorch module=None, instantiated validation loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
    "    `max_steps`: int=5000, maximum number of training steps.<br>\n",
    "    `learning_rate`: float=1e-4, Learning rate between (0, 1).<br>\n",
    "    `num_lr_decays`: int=-1, Number of learning rate decays, evenly distributed across max_steps.<br>\n",
    "    `early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.<br>\n",
    "    `val_check_steps`: int=100, Number of training steps between every validation loss check.<br>\n",
    "    `batch_size`: int=32, number of different series in each batch.<br>\n",
    "    `valid_batch_size`: int=None, number of different series in each validation and test batch, if None uses batch_size.<br>\n",
    "    `windows_batch_size`: int=1024, number of windows to sample in each training batch.<br>\n",
    "    `inference_windows_batch_size`: int=1024, number of windows to sample in each inference batch.<br>\n",
    "    `start_padding_enabled`: bool=False, if True, the model will pad the time series with zeros at the beginning, by input size.<br>\n",
    "    `step_size`: int=1, step size between each window of temporal data.<br>\n",
    "    `scaler_type`: str='identity', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).<br>\n",
    "    `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.<br>\n",
    "    `num_workers_loader`: int=0, workers to be used by `TimeSeriesDataLoader`.<br>\n",
    "    `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>\n",
    "    `alias`: str, optional,  Custom name of the model.<br>\n",
    "    `**trainer_kwargs`: keyword trainer arguments inherited from [PyTorch Lightning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>\n",
    "\n",
    "    *References*<br>\n",
    "    - [Wu, Haixu, Jiehui Xu, Jianmin Wang, and Mingsheng Long. \"Autoformer: Decomposition transformers with auto-correlation for long-term series forecasting\"](https://proceedings.neurips.cc/paper/2021/hash/bcc0d400288793e8bdcd7c19a8ac0c2b-Abstract.html)<br>\n",
    "    \"\"\"\n",
    "    # Class attributes\n",
    "    SAMPLING_TYPE = 'windows'\n",
    "\n",
    "    def __init__(self,\n",
    "                 h: int,\n",
    "                 input_size: int,\n",
    "                 stat_exog_list = None,\n",
    "                 hist_exog_list = None,\n",
    "                 futr_exog_list = None,\n",
    "                 exclude_insample_y = False,\n",
    "                 decoder_input_size_multiplier: float = 0.5,\n",
    "                 hidden_size: int = 128,\n",
    "                 dropout: float = 0.05,\n",
    "                 factor: int = 3,\n",
    "                 n_head: int = 4,\n",
    "                 conv_hidden_size: int = 32,\n",
    "                 activation: str = 'gelu',\n",
    "                 encoder_layers: int = 2,\n",
    "                 decoder_layers: int = 1,\n",
    "                 MovingAvg_window: int = 25,\n",
    "                 loss = MAE(),\n",
    "                 valid_loss = None,\n",
    "                 max_steps: int = 5000,\n",
    "                 learning_rate: float = 1e-4,\n",
    "                 num_lr_decays: int = -1,\n",
    "                 early_stop_patience_steps: int =-1,\n",
    "                 val_check_steps: int = 100,\n",
    "                 batch_size: int = 32,\n",
    "                 valid_batch_size: Optional[int] = None,\n",
    "                 windows_batch_size = 1024,\n",
    "                 inference_windows_batch_size = 1024,\n",
    "                 start_padding_enabled = False,\n",
    "                 step_size: int = 1,\n",
    "                 scaler_type: str = 'identity',\n",
    "                 random_seed: int = 1,\n",
    "                 num_workers_loader: int = 0,\n",
    "                 drop_last_loader: bool = False,\n",
    "                 **trainer_kwargs):\n",
    "        super(Autoformer, self).__init__(h=h,\n",
    "                                       input_size=input_size,\n",
    "                                       hist_exog_list=hist_exog_list,\n",
    "                                       stat_exog_list=stat_exog_list,\n",
    "                                       futr_exog_list = futr_exog_list,\n",
    "                                       exclude_insample_y = exclude_insample_y,\n",
    "                                       loss=loss,\n",
    "                                       valid_loss=valid_loss,\n",
    "                                       max_steps=max_steps,\n",
    "                                       learning_rate=learning_rate,\n",
    "                                       num_lr_decays=num_lr_decays,\n",
    "                                       early_stop_patience_steps=early_stop_patience_steps,\n",
    "                                       val_check_steps=val_check_steps,\n",
    "                                       batch_size=batch_size,\n",
    "                                       windows_batch_size=windows_batch_size,\n",
    "                                       valid_batch_size=valid_batch_size,\n",
    "                                       inference_windows_batch_size=inference_windows_batch_size,\n",
    "                                       start_padding_enabled = start_padding_enabled,\n",
    "                                       step_size=step_size,\n",
    "                                       scaler_type=scaler_type,\n",
    "                                       num_workers_loader=num_workers_loader,\n",
    "                                       drop_last_loader=drop_last_loader,\n",
    "                                       random_seed=random_seed,\n",
    "                                       **trainer_kwargs)\n",
    "\n",
    "        # Architecture\n",
    "        self.futr_input_size = len(self.futr_exog_list)\n",
    "        self.hist_input_size = len(self.hist_exog_list)\n",
    "        self.stat_input_size = len(self.stat_exog_list)\n",
    "\n",
    "        if self.stat_input_size > 0:\n",
    "            raise Exception('Autoformer does not support static variables yet')\n",
    "\n",
    "        if self.hist_input_size > 0:\n",
    "            raise Exception('Autoformer does not support historical variables yet')\n",
    "\n",
    "        # length of the history slice that starts the decoder input\n",
    "        self.label_len = int(np.ceil(input_size * decoder_input_size_multiplier))\n",
    "        if (self.label_len >= input_size) or (self.label_len <= 0):\n",
    "            raise Exception(f'Check decoder_input_size_multiplier={decoder_input_size_multiplier}, range (0,1)')\n",
    "\n",
    "        if activation not in ['relu', 'gelu']:\n",
    "            raise Exception(f'Check activation={activation}')\n",
    "\n",
    "        self.c_out = self.loss.outputsize_multiplier\n",
    "        self.output_attention = False\n",
    "        self.enc_in = 1\n",
    "        self.dec_in = 1\n",
    "\n",
    "        # Decomposition\n",
    "        self.decomp = SeriesDecomp(MovingAvg_window)\n",
    "\n",
    "        # Embedding. The exogenous inputs passed to these embeddings in `forward`\n",
    "        # (x_mark_enc / x_mark_dec) are slices of the future exogenous features, so the\n",
    "        # embeddings are sized by futr_input_size; hist_input_size is always 0 here\n",
    "        # (rejected above) and did not match those inputs.\n",
    "        self.enc_embedding = DataEmbedding(c_in=self.enc_in,\n",
    "                                           exog_input_size=self.futr_input_size,\n",
    "                                           hidden_size=hidden_size,\n",
    "                                           pos_embedding=False,\n",
    "                                           dropout=dropout)\n",
    "        self.dec_embedding = DataEmbedding(self.dec_in,\n",
    "                                           exog_input_size=self.futr_input_size,\n",
    "                                           hidden_size=hidden_size,\n",
    "                                           pos_embedding=False,\n",
    "                                           dropout=dropout)\n",
    "\n",
    "        # Encoder\n",
    "        self.encoder = Encoder(\n",
    "            [\n",
    "                EncoderLayer(\n",
    "                    AutoCorrelationLayer(\n",
    "                        AutoCorrelation(False, factor,\n",
    "                                      attention_dropout=dropout,\n",
    "                                      output_attention=self.output_attention),\n",
    "                        hidden_size, n_head),\n",
    "                    hidden_size=hidden_size,\n",
    "                    conv_hidden_size=conv_hidden_size,\n",
    "                    MovingAvg=MovingAvg_window,\n",
    "                    dropout=dropout,\n",
    "                    activation=activation\n",
    "                ) for _ in range(encoder_layers)\n",
    "            ],\n",
    "            norm_layer=LayerNorm(hidden_size)\n",
    "        )\n",
    "        # Decoder\n",
    "        self.decoder = Decoder(\n",
    "            [\n",
    "                DecoderLayer(\n",
    "                    AutoCorrelationLayer(\n",
    "                        AutoCorrelation(True, factor, attention_dropout=dropout, output_attention=False),\n",
    "                        hidden_size, n_head),\n",
    "                    AutoCorrelationLayer(\n",
    "                        AutoCorrelation(False, factor, attention_dropout=dropout, output_attention=False),\n",
    "                        hidden_size, n_head),\n",
    "                    hidden_size=hidden_size,\n",
    "                    c_out=self.c_out,\n",
    "                    conv_hidden_size=conv_hidden_size,\n",
    "                    MovingAvg=MovingAvg_window,\n",
    "                    dropout=dropout,\n",
    "                    activation=activation,\n",
    "                )\n",
    "                for _ in range(decoder_layers)\n",
    "            ],\n",
    "            norm_layer=LayerNorm(hidden_size),\n",
    "            projection=nn.Linear(hidden_size, self.c_out, bias=True)\n",
    "        )\n",
    "\n",
    "    def forward(self, windows_batch):\n",
    "        \"\"\"\n",
    "        windows_batch: dict with 'insample_y' of shape [Ws, L] and 'futr_exog', which is only\n",
    "        read when futr_exog_list is set (its time axis presumably spans input_size + h steps,\n",
    "        given the slicing below). Returns the loss-domain-mapped forecast of the last h steps.\n",
    "        \"\"\"\n",
    "        # Parse windows_batch (insample_mask, hist_exog and stat_exog are not used:\n",
    "        # historic and static exogenous variables are rejected in __init__)\n",
    "        insample_y    = windows_batch['insample_y']\n",
    "        futr_exog     = windows_batch['futr_exog']\n",
    "\n",
    "        # Parse inputs\n",
    "        insample_y = insample_y.unsqueeze(-1) # [Ws,L,1]\n",
    "        if self.futr_input_size > 0:\n",
    "            x_mark_enc = futr_exog[:,:self.input_size,:]\n",
    "            x_mark_dec = futr_exog[:,-(self.label_len+self.h):,:]\n",
    "        else:\n",
    "            x_mark_enc = None\n",
    "            x_mark_dec = None\n",
    "\n",
    "        # decomp init: the decoder input is the last label_len steps of the decomposed history,\n",
    "        # extended over the horizon with the history mean (trend) and zeros (seasonal)\n",
    "        mean = torch.mean(insample_y, dim=1).unsqueeze(1).repeat(1, self.h, 1)\n",
    "        zeros = torch.zeros([insample_y.shape[0], self.h, insample_y.shape[2]], device=insample_y.device)\n",
    "        seasonal_init, trend_init = self.decomp(insample_y)\n",
    "        # decoder input\n",
    "        trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1)\n",
    "        seasonal_init = torch.cat([seasonal_init[:, -self.label_len:, :], zeros], dim=1)\n",
    "        # enc\n",
    "        enc_out = self.enc_embedding(insample_y, x_mark_enc)\n",
    "        enc_out, _ = self.encoder(enc_out, attn_mask=None)\n",
    "        # dec\n",
    "        dec_out = self.dec_embedding(seasonal_init, x_mark_dec)\n",
    "        seasonal_part, trend_part = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None,\n",
    "                                                 trend=trend_init)\n",
    "        # final\n",
    "        dec_out = trend_part + seasonal_part\n",
    "\n",
    "        forecast = self.loss.domain_map(dec_out[:, -self.h:])\n",
    "        return forecast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# nbdev: show the documentation for the Autoformer class\n",
    "show_doc(Autoformer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# nbdev: show the documentation for Autoformer.fit, titled with its qualified name\n",
    "show_doc(Autoformer.fit, name='Autoformer.fit')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# nbdev: show the documentation for Autoformer.predict, titled with its qualified name\n",
    "show_doc(Autoformer.predict, name='Autoformer.predict')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Usage Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pytorch_lightning as pl\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from neuralforecast import NeuralForecast\n",
    "from neuralforecast.models import MLP\n",
    "from neuralforecast.losses.pytorch import MQLoss, DistributionLoss, MAE\n",
    "from neuralforecast.tsdataset import TimeSeriesDataset\n",
    "from neuralforecast.utils import AirPassengers, AirPassengersPanel, AirPassengersStatic, augment_calendar_df\n",
    "\n",
    "# calendar_cols are the added calendar columns, used below as future exogenous features\n",
    "AirPassengersPanel, calendar_cols = augment_calendar_df(df=AirPassengersPanel, freq='M')\n",
    "\n",
    "# Hold out the last 12 timestamps as the test set\n",
    "Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]] # 132 train\n",
    "Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n",
    "\n",
    "model = Autoformer(h=12,\n",
    "                 input_size=24,\n",
    "                 hidden_size = 16,\n",
    "                 conv_hidden_size = 32,\n",
    "                 n_head=2,\n",
    "                 loss=MAE(),\n",
    "                 futr_exog_list=calendar_cols,\n",
    "                 scaler_type='robust',\n",
    "                 learning_rate=1e-3,\n",
    "                 max_steps=300,\n",
    "                 val_check_steps=50,\n",
    "                 early_stop_patience_steps=2)\n",
    "\n",
    "nf = NeuralForecast(\n",
    "    models=[model],\n",
    "    freq='M'\n",
    ")\n",
    "nf.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)\n",
    "forecasts = nf.predict(futr_df=Y_test_df)\n",
    "\n",
    "Y_hat_df = forecasts.reset_index(drop=False).drop(columns=['unique_id','ds'])\n",
    "plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)\n",
    "plot_df = pd.concat([Y_train_df, plot_df])\n",
    "\n",
    "# Plot one series: history + test targets in black, forecast in blue\n",
    "plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)\n",
    "plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')\n",
    "if model.loss.is_distribution_output:\n",
    "    plt.plot(plot_df['ds'], plot_df['Autoformer-median'], c='blue', label='median')\n",
    "    plt.fill_between(x=plot_df['ds'][-12:],\n",
    "                    y1=plot_df['Autoformer-lo-90'][-12:].values,\n",
    "                    y2=plot_df['Autoformer-hi-90'][-12:].values,\n",
    "                    alpha=0.4, label='level 90')\n",
    "else:\n",
    "    plt.plot(plot_df['ds'], plot_df['Autoformer'], c='blue', label='Forecast')\n",
    "plt.legend()\n",
    "plt.grid()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
